include("/Users/urielyang/OneDrive - Emory University/Honors/SIRQ.jl")

using XLSX, DataFrames, Plots, JLD

# Load the Italy case-count sheet (cells A1:I81 of sheet "Sheet") into a DataFrame.
# readdata returns the raw cell matrix, so row 1 presumably holds the header text;
# the column extraction further down skips it with `2:end`.
# NOTE(review): DataFrames >= 1.0 rejects DataFrame(::Matrix) unless `:auto` is passed
# as a second argument — confirm which DataFrames version this script is pinned to.
df = let
    data_path = "/Users/urielyang/OneDrive - Emory University/Honors/Data_Italy.xlsx"
    cells = XLSX.readdata(data_path, "Sheet!A1:I81")
    DataFrame(cells)
end
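# Parameter notes from earlier runs, as recorded. The bracketed pairs are presumably
# [initial guess] [fitted value] for the two model parameters — TODO confirm.
# China: [0.90 0.023] [0.86 0.024]
# Italy: [0.8 0.03] [1.0 0.036]
# Italy, 50-day training window: better result with step norm 0.03
# SK: 0.7 0.03 0.71 0.01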

# Take spreadsheet columns 5 and 6 as the two observed series, dropping row 1
# (the header row of the A1:I81 range), and convert each to a Vector{Float64}.
# NOTE(review): which quantities these columns hold is not visible here — confirm
# against Data_Italy.xlsx and the expectations of SIRQ.run.
final_data = df[2:end, [5, 6]]
x, y = (convert(Vector{Float64}, final_data[:, col]) for col in 1:2)

# Model configuration handed to SIRQ.run.
# NOTE(review): u0 presumably holds the initial values of the four S/I/R/Q compartments
# and init_params the starting guess for the two fitted parameters — confirm in SIRQ.jl.
u0 = [52e6, 500.0, 10.0, 10.0]
init_params = [0.8, 0.03]

# The 80 data rows (A2:I81) line up with t = 0..79, assuming one row per day.
# training_t_span presumably selects the window the model is fitted on.
total_t_span = [0.0, 79.0]
training_t_span = [1, 50]

# Fit the model, then plot and collect results. The call order is kept as-is because
# the plot/getter functions may rely on state left behind by SIRQ.run.
SIRQ.run([x, y], u0, init_params, total_t_span, training_t_span)
SIRQ.plot_states()
pred_SK = SIRQ.get_pred()
SIRQ.plot_quarantine()
q_SK = SIRQ.get_quarantine()

# Persist the prediction and quarantine series from the run above, plus the loss
# histories when they are available.
# NOTE(review): the quarantine series q_SK is stored under the key "q_WH", which looks
# like a copy-paste leftover from a Wuhan run. The key is left unchanged so existing
# readers of predres_SK.jld keep working; rename it once nothing depends on it.
let results_dir = "/Users/urielyang/OneDrive - Emory University/Honors"
    save(joinpath(results_dir, "predres_SK.jld"), "pred_SK", pred_SK, "q_WH", q_SK)

    # loss_SK / loss_WH / loss_IT are never assigned in this script; they exist only if
    # something else (the included file or an earlier REPL session) defined them.
    # Previously the script aborted here with an UndefVarError; now it warns and skips.
    if @isdefined(loss_SK) && @isdefined(loss_WH) && @isdefined(loss_IT)
        save(joinpath(results_dir, "loss.jld"), "loss_SK", loss_SK,
            "loss_WH", loss_WH, "loss_IT", loss_IT)
    else
        @warn "loss_SK, loss_WH and loss_IT are not all defined; skipping loss.jld"
    end
end
